package com.ikingtech.framework.sdk.gray;

import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.cloud.client.ServiceInstance;
import org.springframework.cloud.client.loadbalancer.DefaultResponse;
import org.springframework.cloud.client.loadbalancer.EmptyResponse;
import org.springframework.cloud.client.loadbalancer.Request;
import org.springframework.cloud.client.loadbalancer.Response;
import org.springframework.cloud.loadbalancer.core.NoopServiceInstanceListSupplier;
import org.springframework.cloud.loadbalancer.core.ReactorServiceInstanceLoadBalancer;
import org.springframework.cloud.loadbalancer.core.ServiceInstanceListSupplier;
import org.springframework.http.HttpHeaders;
import reactor.core.publisher.Mono;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ThreadLocalRandom;

import static com.ikingtech.framework.sdk.context.constant.SecurityConstants.HEADER_VERSION;

/**
 * Version-aware ("gray release") load balancer.
 *
 * <p>The requested version is read from the {@code HEADER_VERSION} request header and compared
 * against the metadata values of every candidate instance. One instance is picked at random among
 * all instances that match; when nothing matches, one is picked at random among all candidates.
 *
 * @author tie yan
 */
@RequiredArgsConstructor
public class GrayLoadBalancer implements ReactorServiceInstanceLoadBalancer {

    /** Version assumed when the request carries no usable version header. */
    private static final String DEFAULT_VERSION = "default";

    private final ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider;

    /**
     * Chooses a service instance for the given request.
     *
     * @param request load-balancer request; its context is used as the source of the version header
     *                when it is an {@link HttpHeaders} instance
     * @return the chosen instance, or an {@link EmptyResponse} when the supplier's first emitted
     *         list contains no instances
     */
    @Override
    public Mono<Response<ServiceInstance>> choose(Request request) {
        // Not every caller puts HttpHeaders into the request context (the library's own default
        // request does not), so guard the cast instead of failing with a ClassCastException.
        // Requests without headers are routed as the default version.
        Object context = request == null ? null : request.getContext();
        HttpHeaders headers = context instanceof HttpHeaders ? (HttpHeaders) context : null;
        ServiceInstanceListSupplier supplier = serviceInstanceListSupplierProvider
                .getIfAvailable(NoopServiceInstanceListSupplier::new);
        return supplier.get().next().map(instances -> this.getInstanceResponse(instances, headers));
    }

    /**
     * Selects one instance from {@code instances} according to the requested version.
     *
     * @param instances candidate instances
     * @param headers   request headers carrying the version; may be {@code null}
     * @return response wrapping the selected instance, or an {@link EmptyResponse} when there are
     *         no candidates
     */
    private Response<ServiceInstance> getInstanceResponse(List<ServiceInstance> instances, HttpHeaders headers) {
        if (instances.isEmpty()) {
            return new EmptyResponse();
        }

        String versionNo = Optional.ofNullable(headers)
                .map(h -> h.getFirst(HEADER_VERSION))
                .orElse(DEFAULT_VERSION);

        // Collect ALL matching instances (no early exit) so that the random pick below actually
        // spreads traffic across every instance of the requested version.
        List<ServiceInstance> matchedInstances = new ArrayList<>();
        for (ServiceInstance instance : instances) {
            Map<String, String> metadata = instance.getMetadata();
            // NOTE(review): this matches the version against ANY metadata value, regardless of key.
            // If the version lives under a dedicated metadata key, compare that key instead.
            if (metadata != null && metadata.containsValue(versionNo)) {
                matchedInstances.add(instance);
            }
        }

        // Fall back to the full candidate list when no instance advertises the requested version.
        List<ServiceInstance> candidates = matchedInstances.isEmpty() ? instances : matchedInstances;
        return new DefaultResponse(pickRandom(candidates));
    }

    /** Returns a uniformly random element of the given non-empty list. */
    private static ServiceInstance pickRandom(List<ServiceInstance> candidates) {
        return candidates.get(ThreadLocalRandom.current().nextInt(candidates.size()));
    }
}
